"""Utility functions for calculating activations and connectivity. Adapted code is acknowledged in comments."""

import numpy as np
import torch
import torch.nn as nn
import os


import time
import copy
import math
import sklearn
import random 

import scipy.spatial     as ss

from math                 import log, sqrt
from scipy                import stats
from sklearn              import manifold
from scipy.special        import *
from sklearn.neighbors    import NearestNeighbors


# Shared store filled by hook_fn: maps each hooked module to the output it
# produced on its most recent forward call. activations() reads from it.
visualisation = dict()

"""
hook_fn(), activations(), and get_all_layers() adapted from: https://blog.paperspace.com/pytorch-hooks-gradient-clipping-debugging/
"""

#### Hook Function
def hook_fn(m, i, o):
    visualisation[m] = o 


### Attach a forward hook to the selected layer so its activation is collected
def get_all_layers(model, hook_handles, item_key):
    """Register ``hook_fn`` as a forward hook on one sub-module of ``model``.

    Despite the function name, only a single sub-module is hooked: the one
    whose position in ``enumerate(model.named_modules())`` equals ``item_key``.

    Args:
        model: Module whose ``named_modules()`` are searched.
        hook_handles: List to which the handle returned by
            ``register_forward_hook`` is appended. The caller uses it to
            remove the hook later.
        item_key: Position (not name) of the sub-module in the
            ``named_modules()`` enumeration.

    If no position equals ``item_key``, nothing is registered and
    ``hook_handles`` is left unchanged.
    """
    for position, (_, submodule) in enumerate(model.named_modules()):
        if position == item_key:
            hook_handles.append(submodule.register_forward_hook(hook_fn))
            # Positions are unique, so no later sub-module can match.
            break


### Run the model and return the recorded activation of the selected layer
def activations(x_input, model, cuda, item_key):
    """Run ``model`` on ``x_input`` and return the output of one sub-module.

    The sub-module is the one at position ``item_key`` in
    ``enumerate(model.named_modules())``, which is the same indexing that
    ``get_all_layers`` uses. A temporary forward hook records that
    sub-module's output in the module-level ``visualisation`` dict. The hook
    is always removed, and the recorded entry is always discarded, even if
    the forward pass raises.

    Args:
        x_input: Input tensor passed to ``model``.
        model: Module to run.
        cuda: If truthy, ``x_input`` is moved to the GPU with ``.cuda()``
            before the forward pass. Otherwise it is used as-is.
            NOTE(review): assumed to be a boolean flag; confirm against callers.
        item_key: Position of the sub-module in the ``named_modules()``
            enumeration.

    Returns:
        A NumPy array holding an independent copy of the selected
        sub-module's output.

    Raises:
        IndexError: If no sub-module sits at position ``item_key``.
        RuntimeError: If the selected sub-module produced no recorded output
            during the forward pass.
    """
    # Resolve the hooked module up front so we read back exactly its output,
    # rather than whichever entry happens to come first in the shared dict.
    target = None
    for position, (_, submodule) in enumerate(model.named_modules()):
        if position == item_key:
            target = submodule
            break
    if target is None:
        raise IndexError(f"model has no sub-module at position {item_key!r}")

    # The `cuda` flag used to be ignored (the input was always moved with
    # .cuda()). Move the input before registering the hook, so a failure
    # here cannot leave a dangling hook behind.
    inputs = x_input.cuda() if cuda else x_input

    handles = []
    get_all_layers(model, handles, item_key)
    try:
        with torch.no_grad():
            model(inputs)
    finally:
        # Always detach the hook and drop the recorded entry, so a failed
        # forward pass leaves no state behind for later calls.
        for handle in handles:
            handle.remove()
        output = visualisation.pop(target, None)

    if output is None:
        raise RuntimeError(
            f"sub-module at position {item_key!r} produced no output during the forward pass"
        )

    # For a CPU tensor, .cpu() returns the same tensor and .numpy() shares its
    # memory, so copy to hand back an independent array.
    return np.copy(output.cpu().numpy())

        
    